import logging
from typing import List, Dict, Any, Optional
from app.base.logger import setup_logger
from app.service.nebula_service import NebulaClient
from app.service.milvus_service import (
    MilvusService,
    MilvusConnectionConfig,
    MilvusFieldConfig
)
from app.service.embedding_service import embedding_service
from app.utils.content_processor import ContentProcessor
from app.utils.config_manager import config

# NOTE(review): the old module-level field-name constant imports were removed;
# Milvus field names now appear to come from MilvusFieldConfig (confirm).

# Module logger, created via the project's setup_logger helper.
logger = setup_logger("knowledge_graph_service")

class KnowledgeGraphService:
    """
    知识图谱查询服务，整合向量搜索和图数据库查询

    要求上游传入Milvus配置
    """
    def __init__(self,
                 milvus_connection_config: MilvusConnectionConfig,
                 milvus_field_config: MilvusFieldConfig):
        """
        Create a knowledge graph service that is not yet connected.

        No connections are opened here; ``initialize()`` must be called
        before the service can be used for searches.

        Args:
            milvus_connection_config: Milvus connection settings (required).
                Handed to ``MilvusService`` by ``initialize()``.
            milvus_field_config: Milvus field settings (required).
                Handed to ``MilvusService`` by ``initialize()``.
        """
        # Caller-supplied Milvus settings, consumed later by initialize().
        self.milvus_connection_config = milvus_connection_config
        self.milvus_field_config = milvus_field_config

        # Backend clients are created in initialize(); until then the
        # service reports itself as not ready and carries no error.
        self.nebula_client: Optional[NebulaClient] = None
        self.milvus_service: Optional[MilvusService] = None
        self.error_message: Optional[str] = None
        self.initialized = False

    def initialize(self):
        """
        Initialize the knowledge graph service.

        Steps, in order, stopping at the first failure:
          1. Connect a new ``NebulaClient`` using ``config.nebula_config``.
          2. Create a ``MilvusService`` from the configs given to ``__init__``
             and connect it.
          3. Initialize the imported ``embedding_service``.

        On any failure ``self.initialized`` is False and ``self.error_message``
        describes the failing step. On success ``self.initialized`` is True
        and ``self.error_message`` is None.

        Returns:
            bool: True if every step succeeded, False otherwise.
        """
        # Clear readiness before doing anything. Previously only the exception
        # path reset this flag, so a failed re-initialization (a connect() or
        # initialize() returning False) left a stale True from an earlier
        # successful run. search_knowledge_graph would then query broken clients.
        self.initialized = False

        try:
            logger.info("初始化知识图谱服务")

            # Connect to NebulaGraph with the globally configured settings.
            self.nebula_client = NebulaClient()
            nebula_success = self.nebula_client.connect(
                ip=config.nebula_config.get("ip"),
                port=config.nebula_config.get("port"),
                user=config.nebula_config.get("user"),
                password=config.nebula_config.get("password"),
                space_name=config.nebula_config.get("space")
            )

            if not nebula_success:
                error_msg = self.nebula_client.get_error_message()
                logger.error(f"连接NebulaGraph失败: {error_msg}")
                self.error_message = f"NebulaGraph连接失败: {error_msg}"
                return False

            # Connect to Milvus using the configuration supplied by the caller.
            # NOTE(review): if this or the embedding step fails, the Nebula
            # connection opened above is not released here; NebulaClient's
            # close API is not visible in this file — confirm and close it.
            self.milvus_service = MilvusService(
                connection_config=self.milvus_connection_config,
                field_config=self.milvus_field_config
            )

            milvus_success = self.milvus_service.connect()

            if not milvus_success:
                error_msg = self.milvus_service.get_error_message()
                logger.error(f"连接Milvus失败: {error_msg}")
                self.error_message = f"Milvus连接失败: {error_msg}"
                return False

            # Initialize the embedding service imported at module level.
            embedding_success = embedding_service.initialize()
            if not embedding_success:
                error_msg = embedding_service.get_error_message()
                logger.error(f"初始化嵌入服务失败: {error_msg}")
                self.error_message = f"嵌入服务初始化失败: {error_msg}"
                return False

            self.initialized = True
            self.error_message = None
            logger.info("知识图谱服务初始化成功")
            return True

        except Exception as e:
            # Broad catch by design: this method reports failure through its
            # return value and error_message instead of raising.
            logger.exception(f"初始化知识图谱服务失败: {str(e)}")
            self.error_message = f"初始化异常: {str(e)}"
            self.initialized = False
            return False
    
    def search_knowledge_graph(self, query_text: str, repo_id: str, top_k: int = 3, related_limit: int = 3):
        """
        搜索知识图谱，包括向量搜索和图数据库查询
        
        Args:
            query_text: 查询文本
            repo_id: 仓库ID，用于过滤结果
            top_k: 每种节点类型返回的最大结果数
            related_limit: 关联节点的最大数量
        
        Returns:
            Dict: 包含搜索结果的字典
        """
        if not self.initialized:
            logger.error("知识图谱服务未初始化，无法执行搜索")
            return {"error": "知识图谱服务未初始化"}
        
        try:
            # 生成查询向量
            query_vector = embedding_service.embed_text(query_text)
            if not query_vector:
                return {"error": "无法生成查询向量"}
            
            # 查询特定类型节点的向量相似度，添加comment和class节点类型
            node_types = ["function", "annotations", "comment", "class"]
            
            # 对每种节点类型分别进行向量搜索
            grouped_results = {}
            
            logger.info(f"执行向量搜索: 查询={query_text[:50]}..., 仓库ID={repo_id}, 节点类型={node_types}")
            
            for node_type in node_types:
                # 为每种节点类型构建过滤条件
                filter_expr = f"node_type == '{node_type}' AND repo_id == '{repo_id}'"

                # 对当前节点类型进行混合搜索
                logger.debug(f"搜索节点类型: {node_type}")
                vector_results = self.milvus_service.hybrid_search(
                    query_text=query_text,
                    query_dense_vector=query_vector,
                    limit=top_k,  # 每种节点类型返回top_k个
                    filter_expr=filter_expr,
                    distance_threshold=0.7,  # 过滤相似度低于0.7的结果
                )

                # 初始化当前节点类型的结果列表
                grouped_results[node_type] = []

                # 处理当前节点类型的搜索结果
                if vector_results and len(vector_results) > 0 and len(vector_results[0]) > 0:
                    for item in vector_results[0]:
                        entity = item.get("entity", {})
                        # 获取相似度分数，默认为0.5如果无法获取
                        similarity = item.get("distance")
                        entity["rank"] = similarity
                        grouped_results[node_type].append(entity)

                        # 确保不超过top_k个结果
                        if len(grouped_results[node_type]) >= top_k:
                            break

                logger.debug(f"节点类型 {node_type} 搜索到 {len(grouped_results[node_type])} 个结果")
            
            # 检查是否有任何搜索结果
            total_results = sum(len(results) for results in grouped_results.values())
            if total_results == 0:
                logger.warning("向量搜索未返回结果")
                return {"error": "向量搜索未返回结果，请检查repo_id是否正确"}

            logger.info(f"向量搜索结果分类完成: {', '.join([f'{k}: {len(v)}个' for k, v in grouped_results.items()])}")
            
            # 存储所有节点的集合(用于去重)
            all_nodes = {}
            
            # 处理搜索到的节点
            for node_type, nodes in grouped_results.items():
                for node in nodes:
                    # 确保节点有ID
                    if "id" in node:
                        node_id = node["id"]
                        if node_id not in all_nodes:
                            all_nodes[node_id] = {
                                "id": node_id,
                                "node_type": node.get("node_type", ""),
                                "full_name": node.get("full_name", ""),
                                "digest": node.get("digest", ""),
                                "rank": node.get("rank")  # 使用从查询结果中提取的相似度
                            }
            
            # 查询相关节点
            related_nodes = {}
            
            # 用于存储类摘要节点
            class_summary_nodes = {}
            
            # 用于记录类中已包含的方法ID，避免重复
            excluded_method_ids = set()
            
            # 1. 对于注解节点，查询使用该注解的方法
            for node in grouped_results.get("annotations", []):
                node_id = node.get("id", "")
                if not node_id:
                    continue
                    
                # 查询使用此注解的方法，使用CASE语句处理不同tag的repo_id
                query = f"""
                MATCH (a:annotations)-[r:contains]-(f:function)
                WHERE id(a) == "{node_id}" AND 
                     CASE 
                         WHEN a.annotations.repo_id IS NOT NULL THEN a.annotations.repo_id == "{repo_id}"
                         WHEN a.repo_id IS NOT NULL THEN a.repo_id == "{repo_id}"
                         ELSE false
                     END AND
                     CASE 
                         WHEN f.function.repo_id IS NOT NULL THEN f.function.repo_id == "{repo_id}"
                         WHEN f.repo_id IS NOT NULL THEN f.repo_id == "{repo_id}"
                         ELSE false
                     END
                RETURN id(f) as func_id, f.function.name, f.function.full_name, 
                       f.function.type, f.function.visibility, f.function.content
                LIMIT {related_limit}
                """
                
                logger.debug(f"查询注解相关方法: 注解ID={node_id}, 仓库ID={repo_id}")
                result = self.nebula_client.execute_query(query)
                
                if result and result.rows():
                    for record in result.rows():
                        values = record.values
                        if len(values) >= 6:
                            func_id = values[0].value.decode('utf-8') if isinstance(values[0].value, bytes) else values[0].value
                            
                            if func_id not in all_nodes and func_id not in related_nodes:
                                related_nodes[func_id] = {
                                    "id": func_id,
                                    "node_type": "function",
                                    "name": values[1].value.decode('utf-8') if isinstance(values[1].value, bytes) else values[1].value,
                                    "full_name": values[2].value.decode('utf-8') if isinstance(values[2].value, bytes) else values[2].value,
                                    "type": values[3].value.decode('utf-8') if isinstance(values[3].value, bytes) else values[3].value,
                                    "visibility": values[4].value.decode('utf-8') if isinstance(values[4].value, bytes) else values[4].value,
                                    "content": values[5].value.decode('utf-8') if isinstance(values[5].value, bytes) else values[5].value,
                                    "relation": "uses_annotation",
                                    "parent_id": node_id
                                }
            
            # 2. 对于方法节点，查询它调用的其他方法
            for node in grouped_results.get("function", []):
                node_id = node.get("id", "")
                if not node_id:
                    continue
                    
                # 查询此方法调用的其他方法，使用CASE语句处理不同tag的repo_id
                query = f"""
                MATCH (f1:function)-[r:calls]->(f2:function)
                WHERE id(f1) == "{node_id}" AND 
                     CASE 
                         WHEN f1.function.repo_id IS NOT NULL THEN f1.function.repo_id == "{repo_id}"
                         WHEN f1.repo_id IS NOT NULL THEN f1.repo_id == "{repo_id}"
                         ELSE false
                     END AND
                     CASE 
                         WHEN f2.function.repo_id IS NOT NULL THEN f2.function.repo_id == "{repo_id}"
                         WHEN f2.repo_id IS NOT NULL THEN f2.repo_id == "{repo_id}"
                         ELSE false
                     END
                RETURN id(f2) as called_id, f2.function.name, f2.function.full_name, 
                       f2.function.type, f2.function.visibility, f2.function.content
                LIMIT {related_limit}
                """
                
                logger.debug(f"查询方法调用关系: 方法ID={node_id}, 仓库ID={repo_id}")
                result = self.nebula_client.execute_query(query)
                
                if result and result.rows():
                    for record in result.rows():
                        values = record.values
                        if len(values) >= 6:
                            called_id = values[0].value.decode('utf-8') if isinstance(values[0].value, bytes) else values[0].value
                            
                            if called_id not in all_nodes and called_id not in related_nodes:
                                related_nodes[called_id] = {
                                    "id": called_id,
                                    "node_type": "function",
                                    "name": values[1].value.decode('utf-8') if isinstance(values[1].value, bytes) else values[1].value,
                                    "full_name": values[2].value.decode('utf-8') if isinstance(values[2].value, bytes) else values[2].value,
                                    "type": values[3].value.decode('utf-8') if isinstance(values[3].value, bytes) else values[3].value,
                                    "visibility": values[4].value.decode('utf-8') if isinstance(values[4].value, bytes) else values[4].value,
                                    "content": values[5].value.decode('utf-8') if isinstance(values[5].value, bytes) else values[5].value,
                                    "relation": "called_by",
                                    "parent_id": node_id
                                }
            
            # 3. 对于注释节点，查询被该注释文档化的方法和类，如果是类则转换为其包含的方法
            for node in grouped_results.get("comment", []):
                node_id = node.get("id", "")
                if not node_id:
                    continue
                    
                # 查询被此注释文档化的方法和类
                query = f"""
                MATCH (c:comment)<-[r:documented_by]-(n)
                WHERE id(c) == "{node_id}" AND 
                     CASE 
                         WHEN c.comment.repo_id IS NOT NULL THEN c.comment.repo_id == "{repo_id}"
                         WHEN c.repo_id IS NOT NULL THEN c.repo_id == "{repo_id}"
                         ELSE false
                     END AND
                     CASE 
                         WHEN n.function.repo_id IS NOT NULL THEN n.function.repo_id == "{repo_id}"
                         WHEN n.class.repo_id IS NOT NULL THEN n.class.repo_id == "{repo_id}"
                         WHEN n.repo_id IS NOT NULL THEN n.repo_id == "{repo_id}"
                         ELSE false
                     END
                RETURN id(n) as documented_id, labels(n)[0] as node_type, 
                       CASE labels(n)[0]
                           WHEN 'function' THEN n.function.name
                           WHEN 'class' THEN n.class.name
                           ELSE ''
                       END as name,
                       CASE labels(n)[0]
                           WHEN 'function' THEN n.function.full_name
                           WHEN 'class' THEN n.class.full_name
                           ELSE ''
                       END as full_name,
                       CASE labels(n)[0]
                           WHEN 'function' THEN n.function.type
                           WHEN 'class' THEN n.class.type
                           ELSE ''
                       END as type,
                       CASE labels(n)[0]
                           WHEN 'function' THEN n.function.visibility
                           WHEN 'class' THEN n.class.visibility
                           ELSE ''
                       END as visibility,
                       CASE labels(n)[0]
                           WHEN 'function' THEN n.function.content
                           WHEN 'class' THEN n.class.content
                           ELSE ''
                       END as content
                LIMIT {related_limit}
                """
                
                logger.debug(f"查询注释文档化的节点: 注释ID={node_id}, 仓库ID={repo_id}")
                result = self.nebula_client.execute_query(query)
                
                if result and result.rows():
                    for record in result.rows():
                        values = record.values
                        if len(values) >= 7:
                            documented_id = values[0].value.decode('utf-8') if isinstance(values[0].value, bytes) else values[0].value
                            documented_node_type = values[1].value.decode('utf-8') if isinstance(values[1].value, bytes) else values[1].value
                            
                            # 如果是方法节点，直接添加
                            if documented_node_type == "function":
                                if documented_id not in all_nodes and documented_id not in related_nodes:
                                    related_nodes[documented_id] = {
                                        "id": documented_id,
                                        "node_type": "function",
                                        "name": values[2].value.decode('utf-8') if isinstance(values[2].value, bytes) else values[2].value,
                                        "full_name": values[3].value.decode('utf-8') if isinstance(values[3].value, bytes) else values[3].value,
                                        "type": values[4].value.decode('utf-8') if isinstance(values[4].value, bytes) else values[4].value,
                                        "visibility": values[5].value.decode('utf-8') if isinstance(values[5].value, bytes) else values[5].value,
                                        "content": values[6].value.decode('utf-8') if isinstance(values[6].value, bytes) else values[6].value,
                                        "relation": "documented_by_comment",
                                        "parent_id": node_id
                                    }
                            # 如果是类节点，生成类摘要（添加到类摘要列表中，稍后统一处理）
                            elif documented_node_type == "class":
                                if documented_id not in class_summary_nodes:
                                    class_summary_nodes[documented_id] = {
                                        "id": documented_id,
                                        "full_name": values[3].value.decode('utf-8') if isinstance(values[3].value, bytes) else values[3].value,
                                        "content": values[6].value.decode('utf-8') if isinstance(values[6].value, bytes) else values[6].value,
                                        "relation": "documented_by_comment",
                                        "parent_id": node_id
                                    }
            
            # 4. 对于类节点，生成类摘要
            for node in grouped_results.get("class", []):
                node_id = node.get("id", "")
                if not node_id:
                    continue
                
                if node_id not in class_summary_nodes:
                    class_summary_nodes[node_id] = {
                        "id": node_id,
                        "full_name": node.get("full_name", ""),
                        "content": node.get("content", ""),
                        "relation": "direct_match",
                        "parent_id": None
                    }
            
            # 5. 为所有类节点生成完整的类摘要结构
            for class_id, class_data in class_summary_nodes.items():
                # 查询类的依赖关系
                depends_query = f"""
                MATCH (c:class)-[r:depends_on]->(dep:class)
                WHERE id(c) == "{class_id}" AND 
                     CASE 
                         WHEN c.class.repo_id IS NOT NULL THEN c.class.repo_id == "{repo_id}"
                         WHEN c.repo_id IS NOT NULL THEN c.repo_id == "{repo_id}"
                         ELSE false
                     END AND
                     CASE 
                         WHEN dep.class.repo_id IS NOT NULL THEN dep.class.repo_id == "{repo_id}"
                         WHEN dep.repo_id IS NOT NULL THEN dep.repo_id == "{repo_id}"
                         ELSE false
                     END
                RETURN id(dep) as dep_id, dep.class.full_name, r.dependency_type
                """
                
                logger.debug(f"查询类依赖关系: 类ID={class_id}, 仓库ID={repo_id}")
                depends_result = self.nebula_client.execute_query(depends_query)
                
                depends_on = []
                if depends_result and depends_result.rows():
                    for record in depends_result.rows():
                        values = record.values
                        if len(values) >= 3:
                            dep_id = values[0].value.decode('utf-8') if isinstance(values[0].value, bytes) else values[0].value
                            dep_full_name = values[1].value.decode('utf-8') if isinstance(values[1].value, bytes) else values[1].value
                            dependency_type = values[2].value.decode('utf-8') if isinstance(values[2].value, bytes) else values[2].value
                            
                            depends_on.append({
                                "node_id": dep_id,
                                "full_name": dep_full_name,
                                "dependency_type": dependency_type
                            })
                
                # 查询类包含的方法
                methods_query = f"""
                MATCH (c:class)-[r:contains]->(f:function)
                WHERE id(c) == "{class_id}" AND 
                     CASE 
                         WHEN c.class.repo_id IS NOT NULL THEN c.class.repo_id == "{repo_id}"
                         WHEN c.repo_id IS NOT NULL THEN c.repo_id == "{repo_id}"
                         ELSE false
                     END AND
                     CASE 
                         WHEN f.function.repo_id IS NOT NULL THEN f.function.repo_id == "{repo_id}"
                         WHEN f.repo_id IS NOT NULL THEN f.repo_id == "{repo_id}"
                         ELSE false
                     END
                RETURN id(f) as method_id, f.function.name, f.function.full_name, 
                       f.function.visibility, f.function.content
                """
                
                logger.debug(f"查询类包含的方法: 类ID={class_id}, 仓库ID={repo_id}")
                methods_result = self.nebula_client.execute_query(methods_query)
                
                function_list = []
                if methods_result and methods_result.rows():
                    for record in methods_result.rows():
                        values = record.values
                        if len(values) >= 5:
                            method_id = values[0].value.decode('utf-8') if isinstance(values[0].value, bytes) else values[0].value
                            method_name = values[1].value.decode('utf-8') if isinstance(values[1].value, bytes) else values[1].value
                            method_full_name = values[2].value.decode('utf-8') if isinstance(values[2].value, bytes) else values[2].value
                            method_visibility = values[3].value.decode('utf-8') if isinstance(values[3].value, bytes) else values[3].value
                            method_content = values[4].value.decode('utf-8') if isinstance(values[4].value, bytes) else values[4].value
                            
                            # 查询方法的注释
                            method_comment_query = f"""
                            MATCH (f:function)-[r:documented_by]->(c:comment)
                            WHERE id(f) == "{method_id}" AND 
                                 CASE 
                                     WHEN f.function.repo_id IS NOT NULL THEN f.function.repo_id == "{repo_id}"
                                     WHEN f.repo_id IS NOT NULL THEN f.repo_id == "{repo_id}"
                                     ELSE false
                                 END AND
                                 CASE 
                                     WHEN c.comment.repo_id IS NOT NULL THEN c.comment.repo_id == "{repo_id}"
                                     WHEN c.repo_id IS NOT NULL THEN c.repo_id == "{repo_id}"
                                     ELSE false
                                 END
                            RETURN c.comment.content
                            LIMIT 1
                            """
                            
                            method_comment_result = self.nebula_client.execute_query(method_comment_query)
                            method_comment = ""
                            if method_comment_result and method_comment_result.rows() and len(method_comment_result.rows()) > 0:
                                comment_value = method_comment_result.rows()[0].values[0]
                                method_comment = comment_value.value.decode('utf-8') if isinstance(comment_value.value, bytes) else comment_value.value
                                if method_comment:
                                    method_comment = ContentProcessor.decompress(method_comment)
                            
                            # 生成方法摘要
                            method_summary = method_full_name
                            if method_comment:
                                method_summary = f"{method_full_name}\n{method_comment}"
                            
                            function_list.append({
                                "node_id": method_id,
                                "full_name": method_full_name,
                                "summary": method_summary,
                                "visibility": method_visibility
                            })
                            
                            # 记录此方法ID，避免重复包含
                            excluded_method_ids.add(method_id)
                
                # 创建类摘要节点
                related_nodes[class_id] = {
                    "id": class_id,
                    "node_type": "class_summary",
                    "full_name": class_data["full_name"],
                    "relation": class_data["relation"],
                    "parent_id": class_data["parent_id"],
                    "class_info": {
                        "depends_on": depends_on,
                        "function_list": function_list
                    }
                }
            
            logger.info(f"查询到 {len(related_nodes)} 个相关节点，其中 {len(class_summary_nodes)} 个类摘要节点")
            logger.info(f"类中包含的方法数量: {len(excluded_method_ids)} 个，将从独立方法列表中排除")

            # 构建层级化结果结构
            # 第一层：向量匹配的节点
            primary_nodes = {}

            # 1. 处理向量匹配的非类节点（排除被类包含的方法）
            for node_id, node in all_nodes.items():
                if node.get("node_type") != "class" and node_id not in excluded_method_ids:
                    primary_nodes[node_id] = {
                        "id": node_id,
                        "node_type": node.get("node_type", ""),
                        "full_name": node.get("full_name", ""),
                        "digest": node.get("digest", ""),
                        "rank": node.get("rank"),
                        "is_direct_match": True,
                        "related_nodes": []
                    }

            # 2. 处理向量匹配的类节点（转换为class_summary并作为第一层节点）
            for node_id, node in all_nodes.items():
                if node.get("node_type") == "class":
                    # 查找对应的class_summary节点
                    if node_id in related_nodes and related_nodes[node_id].get("node_type") == "class_summary":
                        class_summary = related_nodes[node_id]
                        primary_nodes[node_id] = {
                            "id": node_id,
                            "node_type": "class_summary",
                            "full_name": class_summary.get("full_name", ""),
                            "rank": node.get("rank"),  # 使用原始向量匹配的相似度
                            "is_direct_match": True,
                            "class_info": class_summary.get("class_info", {}),
                            "related_nodes": []
                        }

            # 3. 将关联节点分配到对应的第一层节点下
            for node_id, node in related_nodes.items():
                parent_id = node.get("parent_id")

                # 跳过已经作为第一层节点的class_summary
                if node.get("node_type") == "class_summary" and node_id in primary_nodes:
                    continue

                # 跳过被类包含的方法节点
                if node.get("node_type") == "function" and node_id in excluded_method_ids:
                    logger.debug(f"跳过重复的方法节点: {node_id}, 已包含在类摘要中")
                    continue

                # 如果有parent_id且parent存在于第一层节点中，则作为关联节点
                if parent_id and parent_id in primary_nodes:
                    related_node_info = {
                        "id": node_id,
                        "node_type": node.get("node_type", ""),
                        "full_name": node.get("full_name", ""),
                        "digest": node.get("digest", ""),
                        "relation": node.get("relation", ""),
                        "is_direct_match": False
                    }

                    # 处理class_summary类型的关联节点
                    if node.get("node_type") == "class_summary":
                        related_node_info["class_info"] = node.get("class_info", {})

                    primary_nodes[parent_id]["related_nodes"].append(related_node_info)

                # 如果没有parent_id或parent不存在，可能是独立的关联节点，暂时忽略
                # 这种情况在当前的查询逻辑中应该不会出现

            # 4. 查询所有节点的注释
            all_node_ids = list(primary_nodes.keys())
            for primary_node in primary_nodes.values():
                for related_node in primary_node["related_nodes"]:
                    all_node_ids.append(related_node["id"])

            node_comments = {}
            for node_id in all_node_ids:
                # 跳过class_summary节点，因为其方法注释已在生成时处理
                node = primary_nodes.get(node_id)
                if node and node.get("node_type") == "class_summary":
                    continue

                # 检查是否是关联节点中的class_summary
                is_class_summary = False
                for primary_node in primary_nodes.values():
                    for related_node in primary_node["related_nodes"]:
                        if related_node["id"] == node_id and related_node.get("node_type") == "class_summary":
                            is_class_summary = True
                            break
                    if is_class_summary:
                        break

                if is_class_summary:
                    continue

                # 查询节点的注释，使用CASE语句处理不同tag的repo_id
                query = f"""
                MATCH (n)-[r:documented_by]->(c:comment)
                WHERE id(n) == "{node_id}" AND
                     CASE
                         WHEN n.function.repo_id IS NOT NULL THEN n.function.repo_id == "{repo_id}"
                         WHEN n.annotations.repo_id IS NOT NULL THEN n.annotations.repo_id == "{repo_id}"
                         WHEN n.comment.repo_id IS NOT NULL THEN n.comment.repo_id == "{repo_id}"
                         WHEN n.repo_id IS NOT NULL THEN n.repo_id == "{repo_id}"
                         ELSE false
                     END AND
                     CASE
                         WHEN c.comment.repo_id IS NOT NULL THEN c.comment.repo_id == "{repo_id}"
                         WHEN c.repo_id IS NOT NULL THEN c.repo_id == "{repo_id}"
                         ELSE false
                     END
                RETURN id(c) as comment_id, c.comment.content
                LIMIT 1
                """

                result = self.nebula_client.execute_query(query)

                if result and result.rows() and len(result.rows()) > 0:
                    values = result.rows()[0].values
                    if len(values) >= 2:
                        comment_id = values[0].value.decode('utf-8') if isinstance(values[0].value, bytes) else values[0].value
                        comment_content = values[1].value.decode('utf-8') if isinstance(values[1].value, bytes) else values[1].value

                        # 解压注释内容
                        if comment_content:
                            comment_content = ContentProcessor.decompress(comment_content)

                        node_comments[node_id] = {
                            "id": comment_id,
                            "content": comment_content
                        }

            logger.info(f"查询到 {len(node_comments)} 个节点注释")

            # 5. 生成最终的层级化结果
            def generate_node_summary(node_info, node_id):
                """生成节点摘要"""
                node_type = node_info.get("node_type", "")

                if node_type == "class_summary":
                    # 类摘要节点，无需summary字段
                    return {
                        "id": node_id,
                        "node_type": node_type,
                        "full_name": node_info.get("full_name", ""),
                        "is_direct_match": node_info.get("is_direct_match", False),
                        "class_info": node_info.get("class_info", {})
                    }
                else:
                    # 其他节点类型，生成摘要
                    summary_text = ""
                    if node_type == "annotations":
                        # 注解的摘要就是内容
                        content = node_info.get("digest", "")
                        summary_text = content
                    elif node_type == "comment":
                        # 注释的摘要就是内容
                        content = node_info.get("digest", "")
                        summary_text = content
                    else:
                        # 方法的摘要是全名+注释
                        full_name = node_info.get("full_name", "")
                        comment = node_comments.get(node_id, {}).get("content", "")

                        if comment:
                            # 使用注释作为摘要
                            summary_text = f"{full_name}\n{comment}"
                        else:
                            # 仅使用方法名作为摘要
                            summary_text = full_name

                    result = {
                        "id": node_id,
                        "node_type": node_type,
                        "full_name": node_info.get("full_name", ""),
                        "summary": summary_text,
                        "is_direct_match": node_info.get("is_direct_match", False)
                    }

                    # 添加相似度分数（仅对第一层节点）
                    if node_info.get("rank") is not None:
                        result["rank"] = node_info.get("rank")

                    # 添加关系信息（仅对关联节点）
                    if node_info.get("relation"):
                        result["relation"] = node_info.get("relation")

                    return result

            # 构建最终结果
            hierarchical_results = []
            for node_id, primary_node in primary_nodes.items():
                # 生成第一层节点信息
                primary_result = generate_node_summary(primary_node, node_id)

                # 生成关联节点信息
                related_results = []
                for related_node in primary_node["related_nodes"]:
                    related_result = generate_node_summary(related_node, related_node["id"])
                    related_results.append(related_result)

                primary_result["related_nodes"] = related_results
                hierarchical_results.append(primary_result)

            logger.info(f"生成层级化结果: {len(hierarchical_results)} 个第一层节点，"
                       f"总关联节点数: {sum(len(node['related_nodes']) for node in hierarchical_results)}")

            return {
                "query_text": query_text,
                "results": hierarchical_results
            }
            
        except Exception as e:
            logger.exception(f"知识图谱搜索失败: {str(e)}")
            self.error_message = f"搜索异常: {str(e)}"
            return {"error": f"知识图谱搜索失败: {str(e)}"}
    
    def get_node_content(self, node_id: str, repo_id: str = None):
        """
        获取节点内容
        
        Args:
            node_id: 节点ID
            repo_id: 仓库ID，用于过滤结果，可选
        
        Returns:
            Dict: 包含节点内容的字典
        """
        if not self.initialized:
            logger.error("知识图谱服务未初始化，无法获取节点内容")
            return {"error": "知识图谱服务未初始化"}
        
        try:
            # 构建查询条件
            where_clause = f'id(n) == "{node_id}"'
            if repo_id:
                where_clause += f''' AND (
                    CASE 
                        WHEN n.function.repo_id IS NOT NULL THEN n.function.repo_id == "{repo_id}"
                        WHEN n.annotations.repo_id IS NOT NULL THEN n.annotations.repo_id == "{repo_id}"
                        WHEN n.comment.repo_id IS NOT NULL THEN n.comment.repo_id == "{repo_id}"
                        WHEN n.repo_id IS NOT NULL THEN n.repo_id == "{repo_id}"
                        ELSE false
                    END
                )'''
                
            # 查询节点内容
            query = f"""
            MATCH (n)
            WHERE {where_clause}
            RETURN labels(n)[0] as node_type, n
            """
            
            result = self.nebula_client.execute_query(query)
            if not result or not result.rows() or len(result.rows()) == 0:
                logger.warning(f"未找到节点: ID={node_id}")
                return {"error": "未找到节点，请检查node_id或repo_id是否正确"}
            
            # 提取节点类型和属性
            record = result.rows()[0]
            node_type_value = record.values[0]
            node_value = record.values[1]
            
            # 解析节点类型
            node_type = node_type_value.value.decode('utf-8') if isinstance(node_type_value.value, bytes) else node_type_value.value
            
            # 提取节点属性
            properties = {}
            if hasattr(node_value, 'value') and hasattr(node_value.value, 'tags') and node_value.value.tags:
                for tag in node_value.value.tags:
                    tag_name = tag.name.decode('utf-8') if isinstance(tag.name, bytes) else tag.name
                    if hasattr(tag, 'props') and tag.props:
                        for key, val_obj in tag.props.items():
                            key_str = key.decode('utf-8') if isinstance(key, bytes) else key
                            if hasattr(val_obj, 'value'):
                                val = val_obj.value.decode('utf-8') if isinstance(val_obj.value, bytes) else val_obj.value
                                properties[key_str] = val
            
            # 处理content字段
            content = properties.get("content", "")
            if content:
                # 解压内容
                content = ContentProcessor.decompress(content)
                properties["content"] = content
            
            # 查询节点注释
            comment_query = f"""
            MATCH (n)-[r:documented_by]->(c:comment)
            WHERE id(n) == "{node_id}"
            {f'''AND (
                CASE 
                    WHEN n.function.repo_id IS NOT NULL THEN n.function.repo_id == "{repo_id}"
                    WHEN n.annotations.repo_id IS NOT NULL THEN n.annotations.repo_id == "{repo_id}"
                    WHEN n.class.repo_id IS NOT NULL THEN n.class.repo_id == "{repo_id}"
                    WHEN n.repo_id IS NOT NULL THEN n.repo_id == "{repo_id}"
                    ELSE false
                END
                AND
                CASE 
                    WHEN c.comment.repo_id IS NOT NULL THEN c.comment.repo_id == "{repo_id}"
                    WHEN c.repo_id IS NOT NULL THEN c.repo_id == "{repo_id}"
                    ELSE false
                END
            )''' if repo_id else ''}
            RETURN c.comment.content
            LIMIT 1
            """
            
            comment_result = self.nebula_client.execute_query(comment_query)
            if comment_result and comment_result.rows() and len(comment_result.rows()) > 0:
                comment_value = comment_result.rows()[0].values[0]
                comment_content = comment_value.value.decode('utf-8') if isinstance(comment_value.value, bytes) else comment_value.value
                
                # 解压注释内容
                if comment_content:
                    comment_content = ContentProcessor.decompress(comment_content)
                    properties["comment"] = comment_content
            
            return {
                "id": node_id,
                "node_type": node_type,
                "properties": properties
            }
            
        except Exception as e:
            logger.exception(f"获取节点内容失败: {str(e)}")
            self.error_message = f"获取内容异常: {str(e)}"
            return {"error": f"获取节点内容失败: {str(e)}"}
    
    def get_error_message(self):
        """
        Return the most recently recorded error message.

        Returns:
            The current value of ``self.error_message`` (None if nothing has
            been recorded), or "未知错误" when the attribute does not exist.
        """
        try:
            return self.error_message
        except AttributeError:
            return "未知错误"
    
    def close(self):
        """
        Close the NebulaGraph and Milvus connections held by this service.

        Each client is disconnected independently, so a failure on one does
        not leave the other connection open. The service is marked
        uninitialised even if a disconnect fails, because it should not be
        used after close() has been attempted.

        Returns:
            bool: True if every disconnect succeeded, False otherwise.
        """
        success = True
        for client in (self.nebula_client, self.milvus_service):
            if not client:
                continue
            try:
                client.disconnect()
            except Exception as e:
                logger.exception(f"关闭知识图谱服务失败: {str(e)}")
                success = False

        self.initialized = False
        if success:
            logger.info("知识图谱服务已关闭")
        return success

# Factory for a service instance built from the default configuration
def create_default_knowledge_graph_service():
    """
    Build a KnowledgeGraphService wired to the Milvus settings in the config manager.

    Returns:
        KnowledgeGraphService: a new service instance. Its Milvus connection
        comes from ``config.milvus_config``, with defaults of an empty URI and
        the "node_vectors" collection. Its field mapping is the built-in
        default (``id`` / ``text_dense`` / ``text_sparse``). ``initialize()``
        is not called here.
    """
    milvus_settings = config.milvus_config

    connection = MilvusConnectionConfig(
        uri=milvus_settings.get("uri", ""),
        token=milvus_settings.get("token"),
        collection_name=milvus_settings.get("collection_name", "node_vectors"),
        dimension=milvus_settings.get("dimension"),
    )

    # Field layout matching the existing vector collection schema.
    returned_fields = ["id", "node_type", "full_name", "digest", "repo_id", "branch_name"]
    fields = MilvusFieldConfig(
        id_field="id",
        dense_vector_field="text_dense",
        sparse_vector_field="text_sparse",
        output_fields=returned_fields,
    )

    return KnowledgeGraphService(
        milvus_connection_config=connection,
        milvus_field_config=fields,
    )

# Module-level shared instance, created at import time from the default config.
# NOTE(review): only the constructor runs here; initialize() is not called,
# so callers must initialize the instance before querying it.
knowledge_graph_service = create_default_knowledge_graph_service()